import jieba
import torch
from nn.attention import Attention
from nn.encoder import Encoder
from nn.decoder import Decoder
from nn.seq2seq import Seq2Seq
import torch.nn as nn
import torch.optim as optim
import config
from lib.data_center import DataCenter
import lib.util as util
import numpy as np
from torch.utils.data import DataLoader
from lib.dataset import MyDataSet, Seq2seqDataSet

# Run on the GPU when CUDA is available; otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _pad_to_limit(vec):
    """Right-pad a 1-D LongTensor of token ids with 0 (PAD) up to config.LONG_LIMIT.

    Sequences already at or beyond the limit are returned unchanged.
    NOTE(review): overlong sequences are NOT truncated here, matching the
    original behavior — confirm the model tolerates src longer than LONG_LIMIT.
    """
    shortfall = config.LONG_LIMIT - len(vec)
    if shortfall > 0:
        vec = torch.cat([vec, torch.LongTensor([0] * shortfall)])
    return vec


def main():
    """Interactive single-shot inference for the seq2seq chatbot.

    Reads one Chinese sentence from stdin, vectorizes it with the saved
    token vocabulary, and runs it through the trained Seq2Seq model,
    printing the raw output tensor. If the checkpoint cannot be loaded,
    the randomly-initialized weights are saved in its place (best-effort
    bootstrap, preserved from the original).
    """
    real_data = input("输入中文\n")
    # Decoder needs a target tensor even at inference time; feed an empty one.
    trg_mock_data = ''
    data_center = DataCenter()

    token2idx = data_center.load_token()
    attn = Attention(config.ENC_HID_DIM, config.DEC_HID_DIM)
    enc = Encoder(len(token2idx), config.ENC_EMB_DIM, config.ENC_HID_DIM,
                  config.DEC_HID_DIM, config.ENC_DROPOUT)
    dec = Decoder(len(token2idx), config.DEC_EMB_DIM, config.ENC_HID_DIM,
                  config.DEC_HID_DIM, config.DEC_DROPOUT, attn)
    model = Seq2Seq(enc, dec, device).to(device)
    try:
        # map_location lets a GPU-trained checkpoint load on a CPU-only box.
        model.load_state_dict(torch.load(config.seq2seq_model, map_location=device))
    except Exception as e:
        print("读取模型出错", e)
        torch.save(model.state_dict(), config.seq2seq_model)
    # eval() after loading weights: disables dropout for inference.
    model.eval()

    test_transferred = _pad_to_limit(data_center.sentence2vec(real_data, token2idx))
    trg_mock = _pad_to_limit(data_center.sentence2vec(trg_mock_data, token2idx))

    # Add a batch dimension (batch size 1); ensure int64 token ids.
    test_transferred = test_transferred.long().unsqueeze(0)
    print("src - >", test_transferred)

    trg_mock = trg_mock.long().unsqueeze(0)
    print("src.shape = ", test_transferred.shape)
    # .to(device) instead of hard-coded .cuda(): works on CPU-only machines,
    # and no_grad() skips autograd bookkeeping during inference.
    with torch.no_grad():
        res = model(test_transferred.to(device), trg_mock.to(device))
    print("res = ", res)


# Guard the entry point so importing this module doesn't trigger
# interactive inference as a side effect.
if __name__ == "__main__":
    main()
